#coding=utf-8

# Name: debugger
# Date: 12:26 2024/1/8
# Note: 

import os
import sys
import glob
import time
import msvcrt
import queue
import threading

import winreg
import win32api
import win32gui
from win32con import WM_INPUTLANGCHANGEREQUEST

import mysys
import myserial
import myconvert

# Capacity of the shell command-line buffer (Shell(SHELL_CMD_SIZE)) and of the
# keyboard byte queue (queue.Queue(SHELL_CMD_SIZE)).
SHELL_CMD_SIZE = 128

# Payload bytes carried by every download frame.
DOWNLOAD_PKG_SIZE = 64

# Argument passed to myserial.SearchPort; presumably the number of COM ports
# probed during a scan -- confirm against myserial.
SERIAL_PORT_NUM = 30

# Baud rate used whenever a serial port is opened.
SERIAL_BAUDRATE = 115200

# Values of Shell.shell_stat: waiting for a 0xE0 prefix byte, or for the key
# code that follows it.
STAT_WAIT_SPEC_KEY, STAT_WAIT_FUNC_KEY = 0, 1

# Values of Shell.shell_mode; mode switches cycle modulo MODE_MAX.
MODE_NORMAL, MODE_CONTROL = 0, 1
MODE_MAX = 2

# debugger.py -- user-facing texts printed by the shell.

welcome_info = """This is a PuSH Terminal.
Type 'help' for more information.
"""

help_info = """Command:
----------------------------------------
~           -   Switch mode.
quit        -   (q) Quit terminal.
repeat      -   (r) Repeat last command.
help        -   (h) Show help menu.
cd          -   Switch to work path.
load        -   (ld) Load file.
download    -   (d) Download to target device.
go          -   (g) Control target device jump and run.
com         -   Scan valid serial port.
com*        -   Open target serial port.
"""

example_info = """Example:
----------------------------------------
1. download & run
   1) switch to control mode
   2) ld > com > com* > [hardware reset] > [ESC] > d
2. go without download
   1) switch to control mode
   2) [hardware reset] > g
"""


class GlobalParam:
    """Process-wide state shared by the input, output and serial threads."""

    # set to True to make every worker loop stop
    exitFlag = False

    # directory searched for '*.bin' images by the 'cd' / 'load' commands
    workPath = None

    # path, raw bytes and byte count of the last successfully loaded image
    binFile = None
    binData = None
    binSize = 0

    # keyboard byte queue filled by input_thread and drained by output_thread,
    # plus the lock both threads hold around queue access
    workQueue = None
    queueLock = None

    # currentSerial: port that typed keys are written to (None in control mode)
    # lastSerial: port opened by the running serial reader thread
    currentSerial = None
    lastSerial = None

    # the two worker threads started by the main block
    threadInput = None
    threadOutput = None

    def __init__(self):
        self.workQueue = queue.Queue(SHELL_CMD_SIZE)
        self.queueLock = threading.Lock()

    def serial_check(self):
        """Return True while either serial handle is set (a port is open)."""
        return self.currentSerial is not None or self.lastSerial is not None

    def get_desktop_path(self):
        """Return the current user's Desktop folder as recorded in the registry.

        Raises OSError if the registry key or value cannot be read.
        """
        # the handle is closed on exit from the `with`, even if the query fails
        with winreg.OpenKey(winreg.HKEY_CURRENT_USER, r'Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders') as key:
            return winreg.QueryValueEx(key, 'Desktop')[0]

    def get_work_path(self):
        """Return the default work path: two levels above the current directory."""
        return r'..\..'

    def cleanup(self, poll_interval=0.01):
        """Block until the serial reader thread has released its port.

        :param poll_interval: seconds to sleep between checks, so the wait does
            not burn a CPU core.
        """
        while self.serial_check():
            time.sleep(poll_interval)


class NewThread (threading.Thread):
    """Thread that can be asked to terminate itself with kill().

    start() installs a per-thread trace function. Once kill() has been called,
    the next traced 'line' event in a Python frame of this thread raises
    SystemExit there, so try/finally blocks in the target still run. Code
    blocked inside a C call (sleep, blocking I/O) only stops after that call
    returns and the next Python line executes.
    """

    def __init__(self, *args, **keywords):
        threading.Thread.__init__(self, *args, **keywords)
        self.killed = False

    def start(self):
        # route Thread.run through __run so the tracer is installed inside the
        # new thread (sys.settrace only affects the calling thread)
        self.__run_backup = self.run
        self.run = self.__run
        threading.Thread.start(self)

    def __run(self):
        sys.settrace(self.globaltrace)
        try:
            self.__run_backup()
        finally:
            # restore the original run() even when kill() aborted it
            self.run = self.__run_backup

    def globaltrace(self, frame, event, arg):
        """Global tracer: attach localtrace to every newly called frame."""
        if event == 'call':
            return self.localtrace
        else:
            return None

    def localtrace(self, frame, event, arg):
        """Local tracer: raise SystemExit on the next line once killed."""
        if self.killed:
            if event == 'line':
                raise SystemExit()
        return self.localtrace

    def kill(self):
        """Request termination; takes effect at the thread's next traced line."""
        self.killed = True

class Shell:
    """Single-line editor: echoes keys and collects them into ``shell_line``.

    ``shell_line`` holds ``cmdsize`` character codes. ``shell_line_curpos`` is
    the next write index; it wraps back to 0 once the buffer is full, so the
    oldest input is overwritten.
    """

    # set by a command to request a mode switch on the next loop iteration
    shell_trig = False
    # MODE_NORMAL or MODE_CONTROL
    shell_mode = MODE_CONTROL
    # special-key state: STAT_WAIT_SPEC_KEY or STAT_WAIT_FUNC_KEY
    shell_stat = STAT_WAIT_SPEC_KEY
    shell_line = None
    shell_line_curpos = 0

    def __init__(self, cmdsize):
        self.shell_line = [0] * cmdsize

    def output(self, ser, ch, redirect=True):
        """Process one character code: edit the line buffer and echo the result.

        :param ser: serial handle to echo to, or None for the local console
            (passed straight to _xputs).
        :param ch: the character code, as an int.
        :param redirect: when True, Enter (``\\r``) echoes a newline; when
            False it is consumed silently.
        """
        # 0xE0 prefixes a two-byte special key (msvcrt.getch returns it before
        # e.g. arrow-key codes); drop the prefix and the byte after it.
        # NOTE(review): this also swallows 0xE0 + next byte in data read from
        # the serial port (serial_thread feeds received bytes through here),
        # and the 0x00 prefix that getch also uses is not handled.
        if self.shell_stat == STAT_WAIT_SPEC_KEY and ch == 0xe0:
            self.shell_stat = STAT_WAIT_FUNC_KEY
            return
        elif self.shell_stat == STAT_WAIT_FUNC_KEY:
            self.shell_stat = STAT_WAIT_SPEC_KEY
            return

        # tab is ignored
        if ch == ord('\t'):
            return
        # backspace / DEL: erase the previous character on screen and in the buffer
        elif ch == 0x08 or ch == 0x7f:
            if self.shell_line_curpos > 0:
                self.shell_line_curpos -= 1
                self.shell_line[self.shell_line_curpos] = 0
                _xputs(ser, "\b \b")
            return
        # enter
        elif ch == ord('\r'):
            if redirect:
                _xputs(ser, "\n")
            return

        # normal character: store and echo it
        self.shell_line[self.shell_line_curpos] = ch
        _xputs(ser, chr(ch))

        # advance, wrapping to overwrite from the start when the buffer is full.
        # Bounded by this instance's buffer (not the module-wide SHELL_CMD_SIZE)
        # so Shell(cmdsize) works for any cmdsize.
        self.shell_line_curpos += 1
        if self.shell_line_curpos >= len(self.shell_line):
            self.shell_line_curpos = 0


def _is_equal(v1, v2):
    rslt = False
    if type(v2) != type(list()):
        rslt = mysys._is_equal(v1, v2)
    else:
        for v in v2:
            if mysys._is_equal(v1, v):
                rslt = True
                break
    return rslt

def _xputs(ser, info):
    if ser == None:
        mysys._xputs(info)
    else:
        for ch in info:
            myserial.WriteData(ser, myconvert.ToHexStr([ord(ch),]))

def _tostr(data):
    return "".join(chr(i) for i in data).lower()


def serial_thread(port):
    """Open ``port`` and echo everything received on it to the console.

    The opened handle is published in globalParam.currentSerial and
    globalParam.lastSerial. Both are reset to None once the port is closed.
    That happens on exit request, on a failed read, or when the thread is
    stopped with NewThread.kill().
    """
    handle = myserial.OpenPort(port, baud_rate=SERIAL_BAUDRATE)
    globalParam.currentSerial = globalParam.lastSerial = handle

    echo = Shell(SHELL_CMD_SIZE)

    try:
        while not globalParam.exitFlag:
            ok, count, payload = myserial.ReadData(handle)
            if not ok:
                break
            if count <= 0:
                # nothing received: back off briefly before polling again
                time.sleep(0.02)
                continue
            for byte in payload:
                echo.output(None, byte, False)
    finally:
        myserial.ClosePort(handle)
        globalParam.currentSerial = None
        globalParam.lastSerial = None

def input_thread():
    """Poll the console keyboard and queue each raw key byte for output_thread.

    Runs until globalParam.exitFlag is set. While workQueue is full, keys are
    not read and so stay in the console buffer until space frees up.
    """
    while not globalParam.exitFlag:
        # drain every pending key per poll (not just one) so pasted text is
        # not throttled to one character per 20 ms
        while msvcrt.kbhit():
            # `with` releases the lock even if getch()/put() raises
            with globalParam.queueLock:
                if globalParam.workQueue.full():
                    break
                globalParam.workQueue.put(msvcrt.getch())
        time.sleep(0.02)

def _poll_key():
    """Pop one queued key byte and return it as an int, or -1 if none is queued."""
    with globalParam.queueLock:
        if globalParam.workQueue.empty():
            return -1
        return int.from_bytes(globalParam.workQueue.get(), byteorder='big')


def _toggle_mode(shell):
    """Advance ``shell`` to the next mode and route keystrokes accordingly.

    Entering normal mode points currentSerial at the open port (if any) so
    typed characters go to the device. Entering control mode detaches it and
    shows the ':' prompt.
    """
    shell.shell_trig = False
    shell.shell_line_curpos = 0
    shell.shell_mode = (shell.shell_mode + 1) % MODE_MAX
    if shell.shell_mode == MODE_NORMAL:
        _xputs(None, "enter normal mode ...\n")
        if globalParam.serial_check():
            _xputs(None, "serial port is alive.\n")
        globalParam.currentSerial = globalParam.lastSerial
    if shell.shell_mode == MODE_CONTROL:
        _xputs(None, ":")
        globalParam.currentSerial = None


def _list_bin_files(path):
    """Return the '*.bin' files directly inside ``path``."""
    # escape the directory so '[' / '?' in a folder name are not glob syntax
    return glob.glob(os.path.join(glob.escape(path), '*.bin'))


def _stop_serial_thread(tserial):
    """Kill and join the serial reader thread ``tserial`` if it is running."""
    if tserial is not None and tserial.is_alive():
        tserial.kill()
        tserial.join()


def _start_serial_thread(tserial, port, shell):
    """Replace the serial reader thread with a new one reading ``port``.

    Returns the new thread. Once the reader has published its port,
    ``shell.shell_trig`` is set to switch to normal mode. If the reader exits
    without opening it, "open failed." is printed instead.
    """
    _stop_serial_thread(tserial)
    tserial = NewThread(target=serial_thread, args=(port,))
    tserial.start()
    # poll with a sleep instead of spinning, and stop waiting if the reader
    # thread dies (e.g. the port could not be opened) rather than hang forever
    while not globalParam.serial_check() and tserial.is_alive():
        time.sleep(0.01)
    if globalParam.serial_check():
        shell.shell_trig = True
    else:
        _xputs(None, "open failed." + "\n")
    return tserial


def _open_target_port(params):
    """Choose the port for 'download' / 'go'.

    Returns ``(ser, need_close)``. An explicitly named port (``params[1]``) is
    opened here, and the caller must close it. Otherwise the reader thread's
    port is reused, or None is returned when no port is open.
    """
    if len(params) > 1:
        return myserial.OpenPort(params[1], baud_rate=SERIAL_BAUDRATE), True
    if globalParam.serial_check():
        return globalParam.lastSerial, False
    return None, False


def _control_frame(code):
    """Build the 0xA6 <code> <zero padding> 0x6A control frame."""
    return b'\xa6' + bytes([code]) + b'\x00' * (DOWNLOAD_PKG_SIZE - 1) + b'\x6a'


def _send_image(ser, bin_data):
    """Send ``bin_data`` as 0xA5 <payload> 0x5A frames, then the 0xA6 0x01 frame.

    Each data frame carries DOWNLOAD_PKG_SIZE payload bytes, and the last one
    is zero-padded. Prints one '.' per frame. Returns the number of data
    frames sent.
    """
    pkgs = 0
    offset = 0
    while len(bin_data) - offset >= DOWNLOAD_PKG_SIZE:
        myserial.WriteData(ser, b'\xa5' + bin_data[offset:offset + DOWNLOAD_PKG_SIZE] + b'\x5a')
        _xputs(None, ".")
        offset += DOWNLOAD_PKG_SIZE
        pkgs += 1
    # NOTE(review): when the image size is an exact multiple of
    # DOWNLOAD_PKG_SIZE this still sends one all-zero data frame, exactly as
    # the original loop did -- confirm whether the target firmware expects it.
    tail = bin_data[offset:]
    myserial.WriteData(ser, b'\xa5' + tail + b'\x00' * (DOWNLOAD_PKG_SIZE - len(tail)) + b'\x5a')
    _xputs(None, ".")
    pkgs += 1
    myserial.WriteData(ser, _control_frame(0x01))
    _xputs(None, "." + "\n")
    return pkgs


def _cmd_cd():
    """Handle 'cd': make the Desktop the work path and list its '*.bin' files."""
    globalParam.workPath = globalParam.get_desktop_path()
    flist = _list_bin_files(globalParam.workPath)
    _xputs(None, "work path: " + globalParam.workPath + "\n")
    _xputs(None, "find {0} valid file: ".format(len(flist)) + "\n")
    for path in flist:
        _xputs(None, os.path.basename(path) + "\n")
    _xputs(None, "\n")


def _cmd_load(params):
    """Handle 'load [name[.ext]]': read an image from the work path.

    With no argument, the first '*.bin' file in the work path is used. A name
    without an extension gets '.bin' appended. globalParam.binFile, binData
    and binSize are only replaced when the read succeeds.
    """
    file_name = None
    if len(params) > 1:
        name = params[1]
        file_name = name if '.' in name else name + '.bin'
    else:
        flist = _list_bin_files(globalParam.workPath)
        if flist:
            file_name = os.path.basename(flist[0])
    if file_name is not None:
        _xputs(None, "open file: " + file_name + "\n")
        path = os.path.join(globalParam.workPath, file_name)
        try:
            with open(path, 'rb') as fd:
                data = fd.read()
        except OSError:
            _xputs(None, "open failed." + "\n")
        else:
            globalParam.binFile = path
            globalParam.binData = data
            globalParam.binSize = len(data)
            _xputs(None, "read {0} bytes.".format(globalParam.binSize) + "\n")
    else:
        _xputs(None, "invalid param." + "\n")
    _xputs(None, "\n")


def _cmd_download(params, shell):
    """Handle 'download [com*]': send the loaded image to the target."""
    if globalParam.binSize <= 0:
        # checked first so no port is opened (and leaked) without an image
        _xputs(None, "please load bin file first." + "\n")
    else:
        ser, need_close = _open_target_port(params)
        if ser is None:
            _xputs(None, "invalid param." + "\n")
        else:
            try:
                pkgs = _send_image(ser, globalParam.binData)
                _xputs(None, "send {0} pkgs, pkg size {1}.".format(pkgs, DOWNLOAD_PKG_SIZE) + "\n")
            finally:
                # close a port opened only for this command, even if a write fails
                if need_close:
                    myserial.ClosePort(ser)
            shell.shell_trig = True
    _xputs(None, "\n")


def _cmd_go(params, shell):
    """Handle 'go [com*]': send the 0xA6 0x02 control frame (jump and run)."""
    ser, need_close = _open_target_port(params)
    if ser is not None:
        try:
            myserial.WriteData(ser, _control_frame(0x02))
            _xputs(None, "." + "\n")
            _xputs(None, "send {0} pkgs, pkg size {1}.".format(1, DOWNLOAD_PKG_SIZE) + "\n")
        finally:
            if need_close:
                myserial.ClosePort(ser)
        shell.shell_trig = True
    else:
        _xputs(None, "invalid param." + "\n")
    _xputs(None, "\n")


def output_thread():
    """Main shell loop: read queued keys, edit/echo the line, run commands.

    In control mode, Enter runs the typed command. In normal mode, keystrokes
    are forwarded to globalParam.currentSerial. ESC, or a command that sets
    ``shell.shell_trig``, toggles the mode. On return (quit, or an unexpected
    error) globalParam.exitFlag is set so the other threads stop too.
    """
    # search serial ports up front but do not show them
    serial_ports = myserial.SearchPort(SERIAL_PORT_NUM)

    # only one serial reader thread may exist at a time
    tserial = None

    # command re-run by 'repeat'
    last_cmd_line = "help"

    # start in control mode with a ':' prompt
    shell = Shell(SHELL_CMD_SIZE)
    _xputs(None, ":")

    ch = -1
    try:
        while not globalParam.exitFlag:
            # fetch one key unless a mode switch is already pending
            if not shell.shell_trig:
                ch = _poll_key()
                if ch < 0:
                    time.sleep(0.02)
                    continue

            # ESC (0x1b) or a pending trigger switches mode
            if shell.shell_trig or ch == 0x1b:
                _toggle_mode(shell)
                continue

            if shell.shell_mode == MODE_CONTROL and ch in (ord('\r'), ord('\n')):
                cmd_line = _tostr(shell.shell_line[:shell.shell_line_curpos])
                shell.shell_line_curpos = 0
                ## USER COMMAND PARSE AND EXECUTE BEGIN
                if _is_equal(cmd_line, ['quit', 'q']):
                    _xputs(None, "\n")
                    _stop_serial_thread(tserial)
                    globalParam.exitFlag = True
                    continue

                if _is_equal(cmd_line, ['repeat', 'r']):
                    _xputs(None, ":" + last_cmd_line)
                    cmd_line = last_cmd_line

                last_cmd_line = cmd_line
                _xputs(None, "\n")

                # split() (not split(' ')) so repeated spaces do not yield
                # empty arguments
                params = cmd_line.split()
                cmd = params[0] if params else ''

                if _is_equal(cmd_line, ['help', 'h']):
                    _xputs(None, help_info + "\n")
                    _xputs(None, example_info + "\n")

                if _is_equal(cmd, 'cd'):
                    _cmd_cd()

                if _is_equal(cmd, ['load', 'ld']):  # 'load [xxx.bin]'
                    _cmd_load(params)

                if _is_equal(cmd, ['download', 'd']):  # 'download [com*]'
                    _cmd_download(params, shell)

                if _is_equal(cmd, ['go', 'g']):  # 'go [com*]'
                    _cmd_go(params, shell)

                if _is_equal(cmd_line, 'com'):
                    serial_ports = myserial.SearchPort(SERIAL_PORT_NUM)
                    _xputs(None, "valid port: " + str(serial_ports).replace('[', '').replace(']', '').replace('\'', '') + "\n")
                    _xputs(None, "\n")

                if serial_ports is not None:  # 'com*'
                    for port in serial_ports:
                        if _is_equal(cmd_line, port):  # e.g. _is_equal(cmd_line, 'com20')
                            tserial = _start_serial_thread(tserial, port, shell)
                            break

                # a pending mode switch prints its own prompt on the next pass
                if not shell.shell_trig:
                    _xputs(None, ":")
                ## USER COMMAND PARSE AND EXECUTE END
                continue

            if shell.shell_mode in (MODE_CONTROL, MODE_NORMAL):
                shell.output(globalParam.currentSerial, ch, True)
    finally:
        # also reached when a command raises, so input_thread and
        # serial_thread stop instead of keeping the process alive
        globalParam.exitFlag = True

if __name__ == "__main__":
    # module-level global read by every worker function
    globalParam = GlobalParam()

    mysys.env_init(lang='EN', cls=True, curs=False)
    try:
        print(welcome_info, end="\n")

        globalParam.workPath = globalParam.get_work_path()

        globalParam.threadInput = NewThread(target=input_thread)
        globalParam.threadOutput = NewThread(target=output_thread)

        globalParam.threadInput.start()
        globalParam.threadOutput.start()

        try:
            globalParam.threadInput.join()
            globalParam.threadOutput.join()
        finally:
            # on Ctrl+C or any other error the non-daemon workers would keep
            # looping and keep the process alive; ask them to stop first
            globalParam.exitFlag = True
            globalParam.threadInput.join()
            globalParam.threadOutput.join()

        # wait until the serial reader thread has closed its port
        globalParam.cleanup()

        timeout = 0
        print("Exit After %d Second ..." % timeout, end="\n")
        time.sleep(timeout)
    finally:
        # reapply the exit-time environment settings even if startup or a
        # worker thread failed
        mysys.env_init(lang='ZH', cls=False, curs=True)

    print("Exit.", end="\n")
